// [[Rcpp::depends(RcppParallel)]]
#include <cmath>
#include <Rcpp.h>
#include <omp.h>
#include <RcppParallel.h>
using namespace RcppParallel;


struct RBFKernel : public Worker
{
  const RMatrix<double> X;
  const int c;
  const RVector<int> set_c;
  const double gamma;
  RMatrix<double> out;

  RBFKernel(const Rcpp::NumericMatrix X, const int c_, Rcpp::IntegerVector set_c_, double gamma_, Rcpp::NumericMatrix out)
    : X(X), c(c_), set_c(set_c_), gamma(gamma_), out(out) {}

  void operator()(std::size_t begin, std::size_t end) {
    int p = X.ncol();
    for (std::size_t i = begin; i < end; ++i) {
      for (std::size_t j = 0; j < c; ++j) {
        int cj = set_c[j] - 1;
        double dist = 0;
        for (int k = 0; k < p; ++k) {
          double xi = X(i, k), xj = X(cj, k);
          dist += (xi - xj) * (xi - xj);
        }
        out(i, j) = exp(-gamma * dist);
      }
    }
  }
};

// Computes, in parallel over the rows of X, the nrow(X) x c matrix
//
//   out(i, j) = exp(-gamma * ||X[i, ] - X[set_c[j], ]||^2)
//
// where set_c holds 1-based row indices of X and only its first c entries
// are used.
//
// X      numeric matrix whose rows are compared.
// c      number of leading entries of set_c to use; must satisfy
//        0 <= c <= length(set_c).
// set_c  1-based row indices into X; each used entry must lie in 1..nrow(X).
// gamma  multiplier applied to the squared distance.
//
// Raises an R error when c or set_c is out of range. The checks run here, on
// the calling thread, because the worker does no bounds checking and an
// invalid index would otherwise be dereferenced inside a worker thread.
// [[Rcpp::export]]
Rcpp::NumericMatrix RBF_kernel_C_parallel(Rcpp::NumericMatrix X,
                                          int c,
                                          Rcpp::IntegerVector set_c,
                                          double gamma) {
  const int n = X.nrow();

  if (c < 0 || c > set_c.size()) {
    Rcpp::stop("`c` must be between 0 and length(set_c).");
  }
  for (int j = 0; j < c; ++j) {
    const int idx = set_c[j];
    // NA_integer_ is the smallest int, so `idx < 1` rejects it as well.
    if (idx < 1 || idx > n) {
      Rcpp::stop("`set_c` must contain row indices between 1 and nrow(X).");
    }
  }

  Rcpp::NumericMatrix out(n, c);
  if (n == 0 || c == 0) {
    return out;  // nothing to compute; skip scheduling parallel work
  }

  RBFKernel rbfkernel(X, c, set_c, gamma, out);
  parallelFor(0, n, rbfkernel);
  return out;
}